import os

from data_analysis_agent import LLMConfig, DataAnalysisAgent
from data_analysis_agent.utils import LLMHelper
from data_analysis_agent.utils.prompts import user_input

company_profile = "./company_profile/"
analysys_files = [os.path.join(company_profile, "百度利润表.csv"),
                  os.path.join(company_profile, "百度现金流量表.csv"),
                  os.path.join(company_profile, "百度资产负债表.csv")
                  ]
def main():
    llm_config = LLMConfig()
    agent = DataAnalysisAgent(llm_config)
    user_input1= "基于以下公司的数据，输出5个重要的统计指标，并绘制相关图表。最后生成汇报给我。"
    report = agent.analyze(user_input=user_input1,
                           files=analysys_files)
    print(report)

def test():
    llm = LLMHelper(LLMConfig())
    prompt = "这几张表是关于哪一个公司的？请直接返回公司名称即可。" + "\n".join(analysys_files)
    company_name = llm.call(prompt, temperature=0.05)
    print(company_name)
    
if __name__ == "__main__":
    main()
    # test()
    